import os
import re
import random
import argparse
import json
import base64

from utils import LLMEngine
from coco import COCOLoader
from tqdm import tqdm


def encode_image(image_path):
    """Read a file in binary mode and return its bytes base64-encoded as a str.

    Args:
        image_path: Path of the file to encode.

    Returns:
        The base64 text of the file contents, decoded as UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")


def parse_args(argv=None):
    """Parse command-line options for object-identification data generation.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list makes
            the function usable from tests and other scripts.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(
        description='Generate object-identification QA data with an LLM.'
    )
    parser.add_argument('--task', type=str, default="identification", help='Task to be executed')
    parser.add_argument('--model_name', type=str, default="gpt5", help='Model alias to query')
    parser.add_argument('--total_num', type=int, default=5000, help='Maximum number of records to generate')
    parser.add_argument('--output_dir', type=str, default='output/object_identification', help='Output directory for generated data')
    parser.add_argument('--data_type', type=str, default='val2017', help='Data type to use for COCO loader')
    parser.add_argument('--random_seed', type=int, default=0, help='Random seed to use for shuffling')
    parser.add_argument('--prompt_dir', type=str, default='prompts', help='Directory containing prompt files')
    # Restrict to the prompt modes the message builder actually supports, so a
    # typo fails here instead of later on a missing prompt file.
    parser.add_argument('--prompt_type', type=str, default='image_input',
                        choices=('image_input', 'text_input'), help='Prompt type to use')
    parser.add_argument('--prompt_version_id', type=int, default=4, help='Prompt version id to use')
    return parser.parse_args(argv)


def _join_captions(captions):
    """Join caption records (dicts with a ``'caption'`` key) into one newline-separated string, stripping each."""
    return '\n'.join(record['caption'].strip() for record in captions)


def prepare_messages(args, coco, demo_coco, img_id, prompt, demo_ids=None):
    """Build the chat ``messages`` payload for one target image.

    Args:
        args: Parsed CLI namespace; only ``args.prompt_type`` is read.
        coco: Loader for the target image's split (``load_caption`` /
            ``get_img_path`` are called for ``img_id``).
        demo_coco: Loader for the demonstration split (called for every id in
            ``demo_ids``).
        img_id: Id of the target image.
        prompt: Prompt template. ``'text_input'`` templates use ``{caption}``;
            ``'image_input'`` templates must contain ``{caption_1}`` ..
            ``{caption_K}`` with ``K = len(demo_ids) + 1``. The last slot is the
            target image.
        demo_ids: Demonstration image ids, used only for ``'image_input'``.
            ``None`` means no demonstrations.

    Returns:
        A one-element list containing a single ``user`` message. For
        ``'image_input'`` the content is the filled text prompt followed by one
        base64 JPEG data URL per image (demos first, target last).

    Raises:
        ValueError: if ``args.prompt_type`` is unknown, or an ``image_input``
            template lacks one of the required caption placeholders.
    """
    if args.prompt_type == 'text_input':
        caption = _join_captions(coco.load_caption(img_id))
        return [{'role': 'user', 'content': prompt.replace('{caption}', caption)}]

    if args.prompt_type != 'image_input':
        raise ValueError(f"Invalid prompt type: {args.prompt_type}")

    # Pair every id with the loader that owns it: demos come from demo_coco,
    # the target (always last) from coco. `or []` makes demo_ids=None safe.
    sources = [(demo_id, demo_coco) for demo_id in (demo_ids or [])]
    sources.append((img_id, coco))

    # Fill all caption placeholders before touching any image file, so a
    # malformed template fails fast. `raise` (not `assert`) so the check
    # survives `python -O`.
    for slot, (source_id, loader) in enumerate(sources, start=1):
        placeholder = f'{{caption_{slot}}}'
        if placeholder not in prompt:
            raise ValueError(f"Caption placeholder for {slot} not found in prompt")
        prompt = prompt.replace(placeholder, _join_captions(loader.load_caption(source_id)))

    message_content = [{"type": "text", "text": prompt}]
    for source_id, loader in sources:
        image_base64 = encode_image(loader.get_img_path(source_id))
        message_content.append(
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
        )

    return [{"role": "user", "content": message_content}]


# Fields copied verbatim from each generated JSON object into the output JSONL
# record, in this order, after 'img_id'.
_OUTPUT_FIELDS = (
    'question_type',
    'first_hop_question',
    'first_hop_answer',
    'first_hop_wrong_answers',
    'second_hop_question_template',
    'full_question',
    'second_hop_answer',
    'second_hop_wrong_answers',
)


def _parse_generated_records(content, img_id):
    """Parse one model reply (one JSON object per line) into output records.

    Whitespace-only lines are ignored. A line that is not valid JSON, is not a
    JSON object, or lacks any of ``_OUTPUT_FIELDS`` is skipped on its own
    instead of discarding the rest of the reply.

    Args:
        content: Raw reply text; ``None`` is treated as an empty reply.
        img_id: Image id stored as the first field of every record.

    Returns:
        ``(records, skipped)``: the well-formed records and the number of
        non-blank lines that were rejected.
    """
    records = []
    skipped = 0
    for line in (content or '').split('\n'):
        if not line.strip():
            continue
        try:
            item = json.loads(line)
            record = {'img_id': img_id}
            for field in _OUTPUT_FIELDS:
                record[field] = item[field]
        # JSONDecodeError: not JSON (e.g. a ``` fence line); KeyError: missing
        # field; TypeError: valid JSON that is not an object.
        except (json.JSONDecodeError, KeyError, TypeError):
            skipped += 1
            continue
        records.append(record)
    return records, skipped


def generate_object_identification(
    api_key=None,
):
    """Generate object-identification QA records with an LLM and write them as JSONL.

    Configuration comes from the command line (``parse_args``). Image ids of
    ``--data_type`` are shuffled with ``--random_seed``. Ids listed as
    demonstrations are skipped. The model is queried once per remaining image,
    and every well-formed record is written until exactly ``--total_num``
    records exist (or the images run out).

    Args:
        api_key: Key passed to ``LLMEngine``. When ``None`` (the default) it is
            read from the ``OPENAI_API_KEY`` environment variable.

    Raises:
        ValueError: if ``--model_name`` is not one of the known aliases.
    """
    args = parse_args()
    if api_key is None:
        # NOTE(review): assumes LLMEngine expects an OpenAI key (the model ids
        # below are OpenAI ids) — confirm against utils.LLMEngine.
        api_key = os.environ.get('OPENAI_API_KEY')

    model_name_to_id = {
        'gpt5': 'gpt-5-2025-08-07',
        'gpt4.1': 'gpt-4.1-2025-04-14',
        'o4mini': 'o4-mini-2025-04-16',
    }
    # Validate before the (potentially slow) COCO loading below.
    if args.model_name not in model_name_to_id:
        raise ValueError(
            f"Unknown model name: {args.model_name!r}; expected one of {sorted(model_name_to_id)}"
        )
    model_id = model_name_to_id[args.model_name]

    coco = COCOLoader(dataType=args.data_type)
    # Demonstration images are always drawn from val2014, regardless of --data_type.
    demo_coco = COCOLoader(dataType='val2014')

    total_num = args.total_num
    img_ids = coco.getImgIds()
    random.seed(args.random_seed)
    random.shuffle(img_ids)

    llm = LLMEngine(model=model_id, api_key=api_key)

    # Prompt template and its demonstration ids share one path stem.
    prompt_stem = os.path.join(args.prompt_dir, args.prompt_type, f'object_v{args.prompt_version_id}')
    with open(f'{prompt_stem}.txt', 'r', encoding='utf-8') as file:
        prompt = file.read()
    with open(f'{prompt_stem}.json', 'r', encoding='utf-8') as file:
        demo_ids = json.load(file)['img_ids']
    excluded_ids = set(demo_ids)

    os.makedirs(args.output_dir, exist_ok=True)
    output_file_path = os.path.join(args.output_dir, f'object_identification_{args.data_type}_{args.prompt_type}_{args.prompt_version_id}_{args.model_name}_{args.random_seed}_{total_num}.jsonl')

    generated_count = 0
    skipped_lines = 0
    # 'w' truncates any previous run with identical settings; each record is
    # flushed right away so partial progress survives an interrupted run.
    with open(output_file_path, 'w', encoding='utf-8') as output_file:
        for img_id in tqdm(img_ids):
            if generated_count >= total_num:
                break
            # Never ask about an image that is also shown as a demonstration.
            if img_id in excluded_ids:
                continue
            messages = prepare_messages(args, coco, demo_coco, img_id, prompt, demo_ids=demo_ids)
            response = llm.get_response_message(messages)
            choices = getattr(response, 'choices', None)
            if not choices:
                # Best-effort: an empty or failed response skips this image.
                continue
            for choice in choices:
                records, skipped = _parse_generated_records(choice.message.content, img_id)
                skipped_lines += skipped
                # Slice so the output never exceeds total_num records.
                for record in records[:total_num - generated_count]:
                    json.dump(record, output_file)
                    output_file.write('\n')
                    output_file.flush()
                    generated_count += 1
    print(f"Generated {generated_count} data items saved to {output_file_path}")
    if skipped_lines:
        print(f"Skipped {skipped_lines} malformed generated lines")


if __name__ == "__main__":
    # generate_object_identification() takes an API key; supply it from the
    # environment. NOTE(review): assumes an OpenAI-style key — confirm the
    # variable name expected by utils.LLMEngine.
    generate_object_identification(os.environ.get("OPENAI_API_KEY"))
